SR Inference¶
Use trained weights to super-resolve a new lo-res depth and hi-res DEM pair of arbitrary shape.
Also compute performance metrics against a validation depth raster.
In [1]:
# Standard library + scientific stack used throughout inference notebook.
from pathlib import Path
import math
import json
import numpy as np
import pandas as pd
import rasterio
import matplotlib.pyplot as plt
In [2]:
# Reuse shared preprocessing and diagnostics helpers from training code.
from t02.train import (
extract_sr_dem_prediction_np,
invert_depth_log1p_np,
normalize_dem,
prepare_sr_dem_model_inputs_np,
resize_bilinear_2d_np,
scale_depth_log1p_np,
)
import t02.results as results
2026-02-06 14:35:25.309164: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered 2026-02-06 14:35:25.309225: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered 2026-02-06 14:35:25.310300: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Define raster paths
In [3]:
# Input rasters for DEM-conditioned super-resolution inference.
# - dem_fp: hi-res DEM used as conditioning input
# - depth_lores_fp: lo-res depth raster to be super-resolved
# - depth_hires_valid_fp: hi-res depth raster used only for validation metrics
dem_fp = Path('_inputs/RSSHydro/dudelange/002/DEM.tif')
depth_lores_fp = Path('_inputs/RSSHydro/dudelange/032/ResultA.tif')
depth_hires_valid_fp = Path('_inputs/RSSHydro/dudelange/002/ResultA.tif')
# Fail fast with a clear message before any expensive work if inputs are missing.
assert dem_fp.exists(), f"DEM file not found: {dem_fp}"
assert depth_lores_fp.exists(), f"Lo-res depth file not found: {depth_lores_fp}"
assert depth_hires_valid_fp.exists(), f"Hi-res validation depth file not found: {depth_hires_valid_fp}"
In [4]:
# Core run configuration:
# - tile geometry
# - windowing strategy
# - preprocessing constants aligned with training
SCALE = 4  # amount of super-resolution to perform. Must be 4 for this model.
LR_TILE = 64  # model input tile size (lo-res pixels)
HR_TILE = LR_TILE * SCALE  # model output tile size (hi-res pixels)
WINDOW_METHOD = "feather"  # "hard" (no overlap) or "feather" (overlap + feather blending)
FEATHER_OVERLAP_LR = LR_TILE // 4  # overlap size at LR scale (used when WINDOW_METHOD == "feather")
MAX_DEPTH = 5.0  # training clip used in SR_ResUNet_DEM.ipynb
DEM_PCT_CLIP = 95.0  # upper-percentile clip applied during DEM normalization
DEM_REF_STATS = None  # optional train DEM normalization stats loaded from train_config.json
DRY_DEPTH_THRESH_M = results.DEFAULT_DRY_DEPTH_THRESH_M  # keep dry/wet threshold consistent with training diagnostics
# NOTE: MAX_DEPTH / DEM_PCT_CLIP / DEM_REF_STATS may be overridden below when
# train_config.json is found next to the model weights.
assert MAX_DEPTH > 0, f"MAX_DEPTH must be > 0; got {MAX_DEPTH}"
assert WINDOW_METHOD in {"hard", "feather"}, f"Unsupported WINDOW_METHOD: {WINDOW_METHOD!r}"
assert 0 <= FEATHER_OVERLAP_LR < LR_TILE, (
    f"FEATHER_OVERLAP_LR must be in [0, {LR_TILE}); got {FEATHER_OVERLAP_LR}"
)
if WINDOW_METHOD == "feather":
    # Zero overlap would silently degrade feathering to hard windowing.
    assert FEATHER_OVERLAP_LR > 0, "FEATHER_OVERLAP_LR must be > 0 when WINDOW_METHOD='feather'"
model weights¶
see t02/SR_ResUNet_DEM.ipynb for training pipeline
In [5]:
# Load inference artifact and preprocessing metadata.
# train.py exports `model_infer.keras` *after* loading best validation-loss weights,
# so this file can be treated as the best checkpoint for inference.
model_fp = Path('t02/train/full/02/model_infer.keras')
assert model_fp.exists(), f"Model file not found: {model_fp}"
# train_config.json (written beside the model) records the preprocessing
# constants used at training time; prefer them over the notebook defaults so
# inference inputs stay in-distribution.
train_config_fp = model_fp.parent / 'train_config.json'
if train_config_fp.exists():
    train_cfg = json.loads(train_config_fp.read_text())
    # Override defaults only for keys actually present in the config.
    if 'max_depth' in train_cfg:
        MAX_DEPTH = float(train_cfg['max_depth'])
    if 'dem_pct_clip' in train_cfg:
        DEM_PCT_CLIP = float(train_cfg['dem_pct_clip'])
    if 'dem_stats' in train_cfg and train_cfg['dem_stats'] is not None:
        DEM_REF_STATS = {k: float(v) for k, v in train_cfg['dem_stats'].items()}
    # Re-validate after overrides: the config file could carry bad values.
    assert MAX_DEPTH > 0, f"MAX_DEPTH must be > 0; got {MAX_DEPTH}"
    assert 0 < DEM_PCT_CLIP <= 100, f"DEM_PCT_CLIP must be in (0, 100]; got {DEM_PCT_CLIP}"
    if DEM_REF_STATS is not None:
        # normalize_dem needs all three stats to reproduce training scaling.
        required = {'p_clip', 'dem_min', 'dem_max'}
        missing = required.difference(DEM_REF_STATS.keys())
        if missing:
            raise AssertionError(f"dem_stats is missing keys: {sorted(missing)}")
    print(f"Loaded train config from {train_config_fp}")
    print(f"Using MAX_DEPTH={MAX_DEPTH}, DEM_PCT_CLIP={DEM_PCT_CLIP}")
    if DEM_REF_STATS is not None:
        print(f"Using train DEM stats: {DEM_REF_STATS}")
else:
    print(f"train_config.json not found at {train_config_fp}; using notebook defaults for preprocessing.")
Loaded train config from t02/train/full/02/train_config.json
Using MAX_DEPTH=5.0, DEM_PCT_CLIP=95.0
Using train DEM stats: {'dem_max': 1036.0579833984375, 'dem_min': 176.46800231933594, 'p_clip': 1036.0579833984375}
Input Checks¶
In [6]:
# Raster I/O and validation helpers.
# These checks catch geospatial mismatches before model inference.
def _assert_square_pixels(src, name):
res = src.res
if not np.isclose(abs(res[0]), abs(res[1])):
raise AssertionError(f"{name} pixels are not square: res={res}")
def _assert_bounds_close(a, b, name_a, name_b, tol=1e-6):
if not all(np.isclose(av, bv, atol=tol) for av, bv in zip(a, b)):
raise AssertionError(f"{name_a} bounds {a} != {name_b} bounds {b}")
def _assert_no_nodata(arr, name):
if np.ma.isMaskedArray(arr) and np.any(arr.mask):
raise AssertionError(f"{name} contains nodata/masked values")
if np.isnan(np.asarray(arr)).any():
raise AssertionError(f"{name} contains NaNs")
def read_and_check_rasters(dem_fp, depth_lores_fp, depth_hires_fp):
    """Read DEM, lo-res depth, and hi-res depth rasters and validate them.

    Returns (dem, lr, hr, dem_profile, lr_profile, hr_profile): band-1 arrays
    as plain ndarrays plus the rasterio profiles for later GeoTIFF export.

    Raises AssertionError on CRS/bounds mismatch, non-square pixels,
    nodata/NaN values, or out-of-range depth/DEM values.
    """
    with rasterio.open(dem_fp) as dem_src, \
            rasterio.open(depth_lores_fp) as lr_src, \
            rasterio.open(depth_hires_fp) as hr_src:
        # CRS and bounds checks: all inputs must be on the same grid/extent.
        if dem_src.crs != hr_src.crs or lr_src.crs != hr_src.crs:
            raise AssertionError("CRS mismatch between rasters")
        _assert_bounds_close(dem_src.bounds, hr_src.bounds, "DEM", "HR")
        _assert_bounds_close(lr_src.bounds, hr_src.bounds, "LR", "HR")
        _assert_square_pixels(dem_src, "DEM")
        _assert_square_pixels(lr_src, "LR")
        _assert_square_pixels(hr_src, "HR")
        # Read single-band rasters as masked arrays first to detect nodata explicitly.
        dem = dem_src.read(1, masked=True)
        lr = lr_src.read(1, masked=True)
        hr = hr_src.read(1, masked=True)
        _assert_no_nodata(dem, "DEM")
        _assert_no_nodata(lr, "LR depth")
        _assert_no_nodata(hr, "HR depth")
        # Drop the (verified-empty) masks so downstream numpy ops are plain.
        dem = np.asarray(dem)
        lr = np.asarray(lr)
        hr = np.asarray(hr)
        # Physical plausibility limits (site-specific; widen if a site exceeds them).
        if np.min(lr) < 0 or np.min(hr) < 0:
            raise AssertionError("Depth rasters contain negative values")
        if np.max(lr) > 15 or np.max(hr) > 15:
            raise AssertionError("Depth rasters exceed 15m max depth")
        if np.min(dem) < 0 or np.max(dem) > 5000:
            raise AssertionError("DEM raster outside expected range [0, 5000]")
        # scale sanity (warn only): model expects a fixed LR->HR ratio.
        res_ratio = lr_src.res[0] / hr_src.res[0]
        if not np.isclose(res_ratio, SCALE):
            print(f"WARNING: LR/HR resolution ratio {res_ratio:.2f} != SCALE={SCALE}. ")
            print(" LR will be resampled to match model scale.")
        return dem, lr, hr, dem_src.profile, lr_src.profile, hr_src.profile
# Read + validate rasters.
# Arrays come back in native units (DEM elevation and depths in meters);
# profiles are kept for georeferenced export of the SR result later.
dem_raw, lr_raw, hr_raw, dem_profile, lr_profile, hr_profile = read_and_check_rasters(
    dem_fp, depth_lores_fp, depth_hires_valid_fp
)
print("Raw shapes (HR, LR):", hr_raw.shape, lr_raw.shape)
WARNING: LR/HR resolution ratio 10.00 != SCALE=4.
LR will be resampled to match model scale.
Raw shapes (HR, LR): (2030, 2090) (203, 209)
plot raw inputs¶
In [7]:
# Quick sanity plots for raw (un-normalized) inputs before preprocessing.
# Input diagnostics (raw): 2 columns (histogram, raster) x 3 rows (LR depth, HR depth, DEM)
# Each spec: (title, array, colormap, mask-dry-pixels?, dry threshold in meters).
# NOTE(review): this loop is duplicated for the normalized inputs further down;
# consider extracting a shared plotting helper.
plot_specs_raw = [
    ("LR depth (raw)", lr_raw, "viridis", True, DRY_DEPTH_THRESH_M),
    ("HR depth (raw)", hr_raw, "viridis", True, DRY_DEPTH_THRESH_M),
    ("DEM (raw)", dem_raw, "terrain", False, None),
]
fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(10, 12))
for row_idx, (title, arr, cmap, use_dry_mask, dry_thresh) in enumerate(plot_specs_raw):
    arr = np.asarray(arr)
    # Histogram only finite values (inputs were validated NaN-free, but be safe).
    vals = arr[np.isfinite(arr)]
    ax_hist = axes[row_idx, 0]
    ax_raster = axes[row_idx, 1]
    ax_hist.hist(vals, bins=60, color='steelblue', alpha=0.9)
    if use_dry_mask:
        # Red dashed line marks the dry/wet depth threshold on the histogram.
        ax_hist.axvline(dry_thresh, color='red', linestyle='--', linewidth=1.5)
    ax_hist.set_title(f"{title} histogram")
    ax_hist.set_xlabel('Value')
    ax_hist.set_ylabel('Count')
    ax_hist.grid(color='lightgrey', linestyle='-', linewidth=0.7)
    # Annotate summary statistics in the histogram corner.
    ax_hist.text(0.98, 0.95, f"shape: {arr.shape}\nmin: {vals.min():.3f}\nmax: {vals.max():.3f}\nmean: {vals.mean():.3f}\nstd: {vals.std():.3f}",
                 transform=ax_hist.transAxes, fontsize=9, verticalalignment='top', horizontalalignment='right')
    # Mask dry pixels so depth rasters show wet extent only; DEM shown unmasked.
    raster_arr = np.ma.masked_where(arr < dry_thresh, arr) if use_dry_mask else arr
    im = ax_raster.imshow(raster_arr, cmap=cmap)
    ax_raster.set_title(f"{title} raster")
    ax_raster.set_axis_off()
    fig.colorbar(im, ax=ax_raster, fraction=0.046, pad=0.04)
plt.tight_layout()
plt.show()
pre-process validation raster¶
Because our model was trained with a scale of 4, we need to pre-process the inputs to fit model constraints.
In [8]:
# Use CPU for deterministic notebook inference/debugging.
# Hiding GPUs must happen before any TF op initializes the devices,
# so this runs before the model is loaded.
import tensorflow as tf
tf.config.set_visible_devices([], 'GPU')
In [9]:
# Preprocess rasters into model-ready normalized tensors.
# Crop HR/DEM to be divisible by SCALE (avoids edge mismatches).
hr_h, hr_w = hr_raw.shape
crop_h = hr_h - (hr_h % SCALE)
crop_w = hr_w - (hr_w % SCALE)
if (crop_h, crop_w) != (hr_h, hr_w):
    print(f"Cropping HR/DEM from {(hr_h, hr_w)} to {(crop_h, crop_w)} for SCALE alignment")
    hr_raw = hr_raw[:crop_h, :crop_w]
    dem_raw = dem_raw[:crop_h, :crop_w]
# Depth normalization (clip -> log1p scale).
# This must match training exactly so model inputs are in-distribution.
# Report saturation first: fraction of pixels that will be clipped at MAX_DEPTH.
sat_lr = float(np.mean(lr_raw >= MAX_DEPTH))
sat_hr = float(np.mean(hr_raw >= MAX_DEPTH))
print(f"Depth saturation @ MAX_DEPTH: LR={sat_lr:.2%}, HR={sat_hr:.2%}")
# NOTE: lr_raw/hr_raw are rebound to their normalized versions here; the raw
# meter-valued arrays are not recoverable later in this notebook.
lr_raw = scale_depth_log1p_np(lr_raw, max_depth=MAX_DEPTH)
hr_raw = scale_depth_log1p_np(hr_raw, max_depth=MAX_DEPTH)
# DEM normalization: clip negatives, cap upper percentile, then min-max scale.
# Optionally reuse training DEM stats from train_config.json for consistency.
dem_norm, dem_stats = normalize_dem(dem_raw, pct_clip=DEM_PCT_CLIP, ref_stats=DEM_REF_STATS)
assert dem_norm is not None and dem_stats is not None
if DEM_REF_STATS is None:
    print("DEM stats computed from inference raster:", dem_stats)
else:
    print("DEM stats reused from train_config:", dem_stats)
# Resample LR depth to match HR/SCALE spatial size (bilinear).
# Resulting LR grid aligns with expected model input tile shape.
target_lr_h = crop_h // SCALE
target_lr_w = crop_w // SCALE
lr_norm = resize_bilinear_2d_np(lr_raw, (target_lr_h, target_lr_w), antialias=True)
# Resampling can slightly overshoot [0, 1]; clamp back into normalized range.
lr_norm = np.clip(lr_norm, 0.0, 1.0)
assert dem_norm.shape == hr_raw.shape
assert lr_norm.shape == (target_lr_h, target_lr_w)
# Pad to tile multiples for windowed inference.
# Padding avoids partial edge tiles; we crop back after prediction.
pad_h = (int(math.ceil(crop_h / HR_TILE)) * HR_TILE) - crop_h
pad_w = (int(math.ceil(crop_w / HR_TILE)) * HR_TILE) - crop_w
# HR_TILE is a multiple of SCALE and crop_h/crop_w are SCALE-aligned,
# so pad_h/pad_w divide evenly by SCALE for the LR padding below.
hr_pad = np.pad(hr_raw, ((0, pad_h), (0, pad_w)), mode="constant", constant_values=0.0)
dem_pad = np.pad(dem_norm, ((0, pad_h), (0, pad_w)), mode="constant", constant_values=0.0)
lr_pad = np.pad(
    lr_norm,
    ((0, pad_h // SCALE), (0, pad_w // SCALE)),
    mode="constant",
    constant_values=0.0,
)
print("Padded shapes HR/LR:", hr_pad.shape, lr_pad.shape)
Cropping HR/DEM from (2030, 2090) to (2028, 2088) for SCALE alignment
Depth saturation @ MAX_DEPTH: LR=0.00%, HR=0.00%
DEM stats reused from train_config: {'p_clip': 1036.0579833984375, 'dem_min': 176.46800231933594, 'dem_max': 1036.0579833984375}
Padded shapes HR/LR: (2048, 2304) (512, 576)
plot normalized inputs¶
In [10]:
# Sanity plots after normalization to verify model-ready ranges/patterns.
# Input diagnostics: 2 columns (histogram, raster) x 3 rows (LR depth, HR depth, DEM)
# Convert the dry/wet threshold from meters into normalized log1p space so
# the histogram marker lines up with the normalized values being plotted.
dry_thresh_norm = float(scale_depth_log1p_np(np.array([DRY_DEPTH_THRESH_M], dtype=np.float32), max_depth=MAX_DEPTH)[0])
# Each spec: (title, array, colormap, mask-dry-pixels?, dry threshold, normalized units).
# NOTE(review): hr_raw holds the *normalized* HR depth at this point (rebound
# in the preprocessing cell), hence its use under a "(normalized)" title.
plot_specs = [
    ("LR depth (normalized)", lr_norm, "viridis", True, dry_thresh_norm),
    ("HR depth (normalized)", hr_raw, "viridis", True, dry_thresh_norm),
    ("DEM (normalized)", dem_norm, "terrain", False, None),
]
fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(10, 12))
for row_idx, (title, arr, cmap, use_dry_mask, dry_thresh) in enumerate(plot_specs):
    arr = np.asarray(arr)
    # Histogram only finite values (inputs were validated NaN-free, but be safe).
    vals = arr[np.isfinite(arr)]
    ax_hist = axes[row_idx, 0]
    ax_raster = axes[row_idx, 1]
    ax_hist.hist(vals, bins=60, color='steelblue', alpha=0.9)
    if use_dry_mask:
        # Red dashed line marks the normalized dry/wet threshold.
        ax_hist.axvline(dry_thresh, color='red', linestyle='--', linewidth=1.5)
    ax_hist.set_title(f"{title} histogram")
    ax_hist.set_xlabel('Value')
    ax_hist.set_ylabel('Count')
    ax_hist.grid(color='lightgrey', linestyle='-', linewidth=0.7)
    # Annotate summary statistics in the histogram corner.
    ax_hist.text(0.98, 0.95, f"shape: {arr.shape}\nmin: {vals.min():.3f}\nmax: {vals.max():.3f}\nmean: {vals.mean():.3f}\nstd: {vals.std():.3f}",
                 transform=ax_hist.transAxes, fontsize=9, verticalalignment='top', horizontalalignment='right')
    # Mask dry pixels so depth rasters show wet extent only; DEM shown unmasked.
    raster_arr = np.ma.masked_where(arr < dry_thresh, arr) if use_dry_mask else arr
    im = ax_raster.imshow(raster_arr, cmap=cmap)
    ax_raster.set_title(f"{title} raster")
    ax_raster.set_axis_off()
    fig.colorbar(im, ax=ax_raster, fraction=0.046, pad=0.04)
plt.tight_layout()
plt.show()
Per-chip inference first, then mosaicing¶
Run simple non-overlap chip inference for diagnostics, render test-chip style plots, then build the final mosaic while reusing cached chip predictions whenever possible.
In [11]:
# Load the exported inference model directly.
# This artifact already contains best validation-loss weights (see train.py export logic).
# compile=False: no optimizer/loss state is needed for forward passes.
model = tf.keras.models.load_model(model_fp, compile=False)
model.trainable = False  # freeze; this notebook never trains
print(f"Loaded model (best-weight export): {model_fp}")
INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK Your GPU will likely run quickly with dtype policy mixed_float16 as it has compute capability of at least 7.0. Your GPU: NVIDIA RTX 4000 Ada Generation Laptop GPU, compute capability 8.9
Loaded model (best-weight export): t02/train/full/02/model_infer.keras
In [12]:
# Run per-chip inference first (no mosaicing), then build the full-scene mosaic.
# Padded HR mosaic dimensions (exact multiples of HR_TILE).
hr_pad_h, hr_pad_w = hr_pad.shape
# Generate tile start positions and force coverage of the trailing edge.
def build_tile_starts(total_size, tile_size, stride):
    """Return start offsets so [s, s + tile_size) windows cover [0, total_size).

    Starts advance by `stride`; the final start is clamped so the last window
    ends exactly at `total_size` (the last pair of windows may overlap more).
    If `total_size` <= `tile_size`, a single window at 0 is returned.
    """
    # Clamp to 0 so a dimension smaller than one tile yields [0] instead of a
    # negative start (a negative start would silently wrap when used to slice).
    last_start = max(total_size - tile_size, 0)
    starts = list(range(0, last_start + 1, stride))
    if starts[-1] != last_start:
        starts.append(last_start)
    return starts
# Tile-prediction cache keyed by HR-grid tile origin (y0, x0).
# Lets the feather mosaicing pass reuse chips predicted in the non-overlap pass.
tile_pred_cache = {}
# Predict one chip and cache it so feathered mosaicing can reuse non-overlap results.
def _predict_tile(y0, x0):
    """Run the SR model on the tile whose HR-grid origin is (y0, x0).

    Reads the matching LR depth and HR DEM tiles from the padded mosaics,
    returns the (HR_TILE, HR_TILE) normalized prediction, and memoizes it in
    `tile_pred_cache`. y0/x0 are assumed to be multiples of SCALE so the LR
    origin computed below is exact.
    """
    key = (int(y0), int(x0))
    if key in tile_pred_cache:
        return tile_pred_cache[key]
    # LR tile origin: HR origin divided by the SR scale factor.
    lr_y0 = y0 // SCALE
    lr_x0 = x0 // SCALE
    lr_tile = lr_pad[lr_y0 : lr_y0 + LR_TILE, lr_x0 : lr_x0 + LR_TILE]
    dem_tile = dem_pad[y0 : y0 + HR_TILE, x0 : x0 + HR_TILE]
    # Shared helper validates tile shapes and adds batch/channel dims.
    lr_tile_batched, dem_tile_batched = prepare_sr_dem_model_inputs_np(
        depth_lr_norm=lr_tile,
        dem_hr_norm=dem_tile,
        expected_lr_shape=(LR_TILE, LR_TILE, 1),
        expected_hr_shape=(HR_TILE, HR_TILE, 1),
    )
    # Direct call with training=False (inference behavior for any train-time layers).
    pred = model((lr_tile_batched, dem_tile_batched), training=False)
    pred_np = extract_sr_dem_prediction_np(pred, expected_hr_shape=(HR_TILE, HR_TILE, 1))
    tile_pred_cache[key] = pred_np
    return pred_np
In [13]:
# Step 1: simple non-overlap per-chip inference for diagnostics.
nonoverlap_y_starts = list(range(0, hr_pad_h, HR_TILE))
nonoverlap_x_starts = list(range(0, hr_pad_w, HR_TILE))
sr_pad_nonoverlap = np.zeros_like(hr_pad, dtype=np.float32)
print(
    f"Running simple non-overlap per-chip inference on "
    f"{len(nonoverlap_y_starts) * len(nonoverlap_x_starts)} chips..."
)
# Padded mosaic is an exact multiple of HR_TILE, so tiles abut with no remainder.
for y0 in nonoverlap_y_starts:
    for x0 in nonoverlap_x_starts:
        pred_np = _predict_tile(y0, x0)
        sr_pad_nonoverlap[y0 : y0 + HR_TILE, x0 : x0 + HR_TILE] = pred_np
# Build withheld-test-like chip stacks from fully valid (non-padded) chip locations.
# Chips overlapping the zero padding are excluded so diagnostics ignore synthetic pixels.
valid_tiles_h = crop_h // HR_TILE
valid_tiles_w = crop_w // HR_TILE
if valid_tiles_h == 0 or valid_tiles_w == 0:
    raise ValueError(
        f"No fully valid chips for diagnostics (crop={(crop_h, crop_w)}, HR_TILE={HR_TILE})."
    )
n_valid = valid_tiles_h * valid_tiles_w
# Pre-allocate (N, H, W, 1) stacks matching the training diagnostics layout.
lowres_chips = np.zeros((n_valid, LR_TILE, LR_TILE, 1), dtype=np.float32)
highres_chips = np.zeros((n_valid, HR_TILE, HR_TILE, 1), dtype=np.float32)
preds_chips = np.zeros((n_valid, HR_TILE, HR_TILE, 1), dtype=np.float32)
chip_coords = []
chip_idx = 0
for ty in range(valid_tiles_h):
    y0 = ty * HR_TILE
    for tx in range(valid_tiles_w):
        x0 = tx * HR_TILE
        lr_y0 = y0 // SCALE
        lr_x0 = x0 // SCALE
        lowres_chips[chip_idx, ..., 0] = lr_pad[lr_y0 : lr_y0 + LR_TILE, lr_x0 : lr_x0 + LR_TILE]
        highres_chips[chip_idx, ..., 0] = hr_pad[y0 : y0 + HR_TILE, x0 : x0 + HR_TILE]
        # Reuse the prediction cached during the non-overlap pass above.
        preds_chips[chip_idx, ..., 0] = tile_pred_cache[(y0, x0)]
        chip_coords.append((y0, x0))
        chip_idx += 1
chip_coords = np.asarray(chip_coords, dtype=np.int32)
print(f"Prepared {chip_idx} valid chips ({valid_tiles_h} x {valid_tiles_w}) for diagnostics.")
Running simple non-overlap per-chip inference on 72 chips... Prepared 56 valid chips (7 x 8) for diagnostics.
In [14]:
# Compute per-chip summary + per-sample metrics (model vs bilinear baseline).
chip_summary, chip_per_sample = results.evaluate_chip_arrays_vs_bilinear(
    lowres_chips=lowres_chips,
    highres_chips=highres_chips,
    preds_chips=preds_chips,
    max_depth=MAX_DEPTH,
    split_name='inference_chips',
    dry_depth_thresh_m=DRY_DEPTH_THRESH_M,
)
# Mirror the dict layout used by the training diagnostics so shared plotting
# helpers can be reused unchanged ('test' split == these inference chips).
GLOBAL_METRICS = {'test': chip_summary}
PER_SAMPLE_METRICS = {'test': chip_per_sample}
print('Per-chip GLOBAL_METRICS[test]:')
print(json.dumps(GLOBAL_METRICS['test'], indent=2, sort_keys=True))
# Reproduce training_results.ipynb diagnostics on withheld-style chips.
test_samples = PER_SAMPLE_METRICS['test']
fig_scatter, _ = results.plot_metric_scatter_vs_mean_depth(test_samples)
plt.show()
plt.close(fig_scatter)  # free figure memory after rendering
Per-chip GLOBAL_METRICS[test]:
{
"baseline": {
"CSI": 0.21588961780071259,
"MAE": 0.017262272536754608,
"PSNR": 27.65399169921875,
"RMSE": 0.0518011711537838,
"RMSE_wet": 0.42853009700775146,
"SSIM": 0.8069341778755188
},
"best_epoch": {
"CSI": 0.21867890655994415,
"MAE": 0.01664661057293415,
"PSNR": 27.6664981842041,
"RMSE": 0.05031589791178703,
"RMSE_wet": 0.45040664076805115,
"SSIM": 0.8093225359916687
}
}
In [15]:
# Per-chip statistic scatter plot (shared results helper).
fig_chip_scatter, _ = results.plot_chip_stat_scatter(test_samples)
plt.show()
plt.close(fig_chip_scatter)  # free figure memory after rendering
In [16]:
# Render best/worst chips (by SR MAE) for qualitative inspection.
# chip_ids carries (y0, x0) tile origins so plots can be traced back to map locations.
_ = results.plot_best_worst_chip_examples(
    lowres_chips=lowres_chips,
    highres_chips=highres_chips,
    preds_chips=preds_chips,
    max_depth=MAX_DEPTH,
    n_show=3,
    dry_depth_thresh_m=DRY_DEPTH_THRESH_M,
    cmap='cividis',
    chip_ids=chip_coords,
)
Scanned test chips: 56 Retained candidate chips in memory: 6 Worst chips (highest SR MAE):
Best chips (lowest SR MAE):
/usr/local/lib/python3.11/dist-packages/numpy/lib/histograms.py:885: RuntimeWarning: divide by zero encountered in divide return n/db/n.sum(), bin_edges /usr/local/lib/python3.11/dist-packages/numpy/lib/histograms.py:885: RuntimeWarning: invalid value encountered in divide return n/db/n.sum(), bin_edges
In [17]:
# Step 2: build final mosaic, reusing cached chip predictions whenever possible.
if WINDOW_METHOD == 'hard':
    # Hard windowing == the non-overlap pass we already ran; just copy it.
    sr_pad = sr_pad_nonoverlap.copy()
    print(
        f"Mosaicing with WINDOW_METHOD='hard' using cached chips only "
        f"({len(tile_pred_cache)} predictions cached)."
    )
elif WINDOW_METHOD == 'feather':
    # Overlapping windows blended with a linear ramp to hide tile seams.
    overlap_hr = FEATHER_OVERLAP_LR * SCALE
    stride_hr = HR_TILE - overlap_hr
    if stride_hr <= 0:
        raise AssertionError(
            f"Feather stride must be > 0; got stride_hr={stride_hr} from FEATHER_OVERLAP_LR={FEATHER_OVERLAP_LR}"
        )
    y_starts = build_tile_starts(hr_pad_h, HR_TILE, stride_hr)
    x_starts = build_tile_starts(hr_pad_w, HR_TILE, stride_hr)
    # 1-D feather profile: ramp up over the leading overlap, flat 1.0 middle,
    # ramp down over the trailing overlap.
    feather_1d = np.ones(HR_TILE, dtype=np.float32)
    if overlap_hr > 0:
        # Drop the exact 0.0/1.0 endpoints so no pixel gets zero weight.
        ramp = np.linspace(0.0, 1.0, overlap_hr + 2, dtype=np.float32)[1:-1]
        feather_1d[:overlap_hr] = ramp
        feather_1d[-overlap_hr:] = ramp[::-1]
    feather_1d = np.clip(feather_1d, 1e-3, 1.0)
    # Weighted accumulation buffers; final mosaic = accum / weight_sum.
    accum = np.zeros_like(hr_pad, dtype=np.float32)
    weight_sum = np.zeros_like(hr_pad, dtype=np.float32)
    print(
        f"Mosaicing with WINDOW_METHOD='feather' using {len(y_starts) * len(x_starts)} windows "
        f"(overlap={overlap_hr} px, stride={stride_hr} px)..."
    )
    for yi, y0 in enumerate(y_starts):
        for xi, x0 in enumerate(x_starts):
            pred_np = _predict_tile(y0, x0)
            wy = feather_1d.copy()
            wx = feather_1d.copy()
            # Mosaic borders have no neighbor to blend with: use full weight
            # there so edge pixels are not attenuated by the ramp.
            if yi == 0:
                wy[:overlap_hr] = 1.0
            if yi == len(y_starts) - 1:
                wy[-overlap_hr:] = 1.0
            if xi == 0:
                wx[:overlap_hr] = 1.0
            if xi == len(x_starts) - 1:
                wx[-overlap_hr:] = 1.0
            # Separable 2-D weight from the two 1-D profiles.
            weight = np.outer(wy, wx).astype(np.float32, copy=False)
            accum[y0 : y0 + HR_TILE, x0 : x0 + HR_TILE] += pred_np * weight
            weight_sum[y0 : y0 + HR_TILE, x0 : x0 + HR_TILE] += weight
    # Normalize by accumulated weights; `where` guards never-touched pixels
    # (none should exist given full tiling, but keep the division safe).
    sr_pad = np.divide(
        accum,
        np.maximum(weight_sum, 1e-6),
        out=np.zeros_like(accum),
        where=weight_sum > 0,
    )
    print(f"Cached predictions after feather mosaicing: {len(tile_pred_cache)}")
else:
    raise ValueError(f"Unsupported WINDOW_METHOD: {WINDOW_METHOD!r}")
# Remove padding and keep normalized prediction range bounded.
sr = np.clip(sr_pad[:crop_h, :crop_w], 0.0, 1.0)
hr_valid = hr_raw[:crop_h, :crop_w]
print('SR shape:', sr.shape)
Mosaicing with WINDOW_METHOD='feather' using 132 windows (overlap=64 px, stride=192 px)... Cached predictions after feather mosaicing: 188 SR shape: (2028, 2088)
In [18]:
# Export prediction as GeoTIFF in depth units (meters).
ofp = Path('t02/inference/sr_result.tif')
ofp.parent.mkdir(parents=True, exist_ok=True)
# Convert model output from normalized log-space back to depth (meters).
sr_depth_m = invert_depth_log1p_np(sr, max_depth=MAX_DEPTH)
print(f"Saving SR result to {ofp}...")
# Reuse the HR validation raster's georeferencing; override size to the
# (possibly SCALE-cropped) SR array dimensions.
out_profile = hr_profile.copy()
out_profile.update(
    dtype=rasterio.float32,
    count=1,
    compress='deflate',
    height=sr_depth_m.shape[0],
    width=sr_depth_m.shape[1],
)
with rasterio.open(ofp, 'w', **out_profile) as dst:
    dst.write(sr_depth_m.astype(np.float32), 1)
print(
    f"Saved depth raster in meters. Range: min={float(np.nanmin(sr_depth_m)):.4f}, max={float(np.nanmax(sr_depth_m)):.4f}"
)
Saving SR result to t02/inference/sr_result.tif... Saved depth raster in meters. Range: min=0.0000, max=4.4980
Bilinear baseline + mosaic-level metrics¶
In [19]:
# Build a bilinear baseline from LR input and compare against HR target.
# Baseline upsamples the padded LR (normalized space), then crops like SR did.
baseline_pad = resize_bilinear_2d_np(lr_pad, (hr_pad_h, hr_pad_w), antialias=True)
baseline = np.clip(baseline_pad[:crop_h, :crop_w], 0.0, 1.0)
# Compute mosaic-level metrics in normalized space with shared helpers.
# Add batch + channel dims: the metric helpers consume (N, H, W, C) tensors.
hr_full = tf.convert_to_tensor(hr_valid[None, ..., None], dtype=tf.float32)
sr_full = tf.convert_to_tensor(sr[None, ..., None], dtype=tf.float32)
bl_full = tf.convert_to_tensor(baseline[None, ..., None], dtype=tf.float32)
sr_metric_tensors = results.compute_per_sample_metrics(hr_full, sr_full)
bl_metric_tensors = results.compute_per_sample_metrics(hr_full, bl_full)
# Wrap each metric in a one-element buffer so the shared reducer can be reused.
metrics_sr = results.reduce_metric_buffers({k: [v] for k, v in sr_metric_tensors.items()})
metrics_bilinear = results.reduce_metric_buffers({k: [v] for k, v in bl_metric_tensors.items()})
df = pd.DataFrame({
    'ResUNet': metrics_sr,
    'Bilinear': metrics_bilinear,
})
# Fix row order to the canonical metric list; last expression renders the table.
df = df.loc[list(results.METRIC_KEYS), ['ResUNet', 'Bilinear']]
df.round(4)
Out[19]:
| ResUNet | Bilinear | |
|---|---|---|
| MAE | 0.0151 | 0.0156 |
| PSNR | 25.2495 | 24.4981 |
| SSIM | 0.8264 | 0.8244 |
| RMSE | 0.0546 | 0.0596 |
| RMSE_wet | 0.4095 | 0.3185 |
| CSI | 0.2944 | 0.2879 |
PLOT Inference (mosaic-level)¶
In [20]:
# Final full-scene diagnostics in depth units (meters).
# Plot final full-scene inference with shared depth-domain metrics and histograms.
DRY_DEPTH_THRESH_PLOT_M = DRY_DEPTH_THRESH_M
# Channel dim only here (no batch dim), unlike the metrics cell above —
# presumably plot_chip_comparison expects single (H, W, C) chips; verify in results.py.
full_lr = tf.convert_to_tensor(lr_norm[..., None], dtype=tf.float32)
full_hr = tf.convert_to_tensor(hr_valid[..., None], dtype=tf.float32)
full_sr = tf.convert_to_tensor(sr[..., None], dtype=tf.float32)
print("Full-scene inference diagnostics")
fig, final_metrics = results.plot_chip_comparison(
    highres=full_hr,
    lowres=full_lr,
    preds=full_sr,
    max_depth=MAX_DEPTH,
    dry_depth_thresh_m=DRY_DEPTH_THRESH_PLOT_M,
    cmap="cividis",
)
plt.show()
plt.close(fig)  # free figure memory after rendering
tile_label = "full-scene"
# Headline metrics; the full metrics dict is the cell output below.
print("PSNR between LR and HR image {}: {:.4f}".format(tile_label, final_metrics["lr_psnr"]))
print("SSIM between LR and HR image {}: {:.4f}".format(tile_label, final_metrics["lr_ssim"]))
print("PSNR between HR and SR image {}: {:.4f}".format(tile_label, final_metrics["sr_psnr"]))
print("SSIM between HR and SR image {}: {:.4f}".format(tile_label, final_metrics["sr_ssim"]))
print("MAE between HR and SR image {}: {:.6f} m".format(tile_label, final_metrics["sr_mae_m"]))
final_metrics
Full-scene inference diagnostics
PSNR between LR and HR image full-scene: 28.4491 SSIM between LR and HR image full-scene: 0.8828 PSNR between HR and SR image full-scene: 29.0922 SSIM between HR and SR image full-scene: 0.8889 MAE between HR and SR image full-scene: 0.037980 m
Out[20]:
{'lr_psnr': 28.449148178100586,
'lr_ssim': 0.8827940225601196,
'lr_mae_m': 0.04028255119919777,
'lr_rmse_m': 0.18902209401130676,
'lr_rmse_wet_m': 0.5304687023162842,
'lr_bias_m': 0.010088089853525162,
'lr_wet_pixel_count': 336221,
'lr_dry_pixel_count': 3898243,
'sr_psnr': 29.0921688079834,
'sr_ssim': 0.8889098763465881,
'sr_mae_m': 0.03798045590519905,
'sr_rmse_m': 0.17553409934043884,
'sr_rmse_wet_m': 0.5603041052818298,
'sr_bias_m': -0.0012882208684459329,
'sr_wet_pixel_count': 336221,
'sr_dry_pixel_count': 3898243,
'hr_wet_pixel_count': 336221,
'hr_dry_pixel_count': 3898243,
'hr_mean_depth_m': 0.03274461627006531,
'hr_max_depth_m': 5.0,
'hr_min_depth_m': 0.0,
'bl_psnr': 28.636943817138672,
'bl_ssim': 0.8845934271812439,
'bl_mae_m': 0.039947520941495895,
'bl_rmse_m': 0.1849791556596756,
'bl_rmse_wet_m': 0.5248517990112305,
'bl_bias_m': 0.009580918587744236,
'bl_wet_pixel_count': 336221,
'bl_dry_pixel_count': 3898243}